# SPDX-FileCopyrightText: Copyright (c) 2018-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio

import numpy as np
import omni
from isaacsim.core.simulation_manager import SimulationManager
from isaacsim.core.utils.stage import add_reference_to_stage, open_stage_async
from isaacsim.test.utils import TimedAsyncTestCase
from pxr import UsdPhysics


class ROS2TestCase(TimedAsyncTestCase):
    """Base test class that automatically times all test methods.

    This class extends TimedAsyncTestCase to add ROS2 specific setup and teardown.
    It also provides helper methods for creating and managing ROS2 resources with automatic cleanup.
    """

    async def setUp(self):
        """Set up test timing before each test method."""
        await super().setUp()
        self._timeline = omni.timeline.get_timeline_interface()

        # Initialize lists to track ROS2 resources for automatic cleanup
        self._ros2_nodes = []
        self._ros2_publishers = []
        self._ros2_subscribers = []

        import rclpy

        if not rclpy.ok():
            rclpy.init()

        SimulationManager.set_backend("numpy")
        SimulationManager.set_physics_sim_device("cpu")
        SimulationManager.enable_fabric(enable=False)
        await omni.kit.app.get_app().next_update_async()

    def create_node(self, node_name):
        """Create a ROS2 node and track it for automatic cleanup.

        Args:
            node_name: Name of the ROS2 node to create.

        Returns:
            The created ROS2 node.
        """
        import rclpy

        node = rclpy.create_node(node_name)
        self._ros2_nodes.append(node)
        return node

    def create_publisher(self, node, msg_type, topic_name, qos_profile=10):
        """Create a ROS2 publisher and track it for automatic cleanup.

        Args:
            node: The ROS2 node to create the publisher on.
            msg_type: The message type for the publisher.
            topic_name: The topic name to publish to.
            qos_profile: QoS profile for the publisher (default: 10).

        Returns:
            The created ROS2 publisher.
        """
        publisher = node.create_publisher(msg_type, topic_name, qos_profile)
        self._ros2_publishers.append((node, publisher))
        return publisher

    def create_subscription(self, node, msg_type, topic_name, callback, qos_profile=10):
        """Create a ROS2 subscription and track it for automatic cleanup.

        Args:
            node: The ROS2 node to create the subscription on.
            msg_type: The message type for the subscription.
            topic_name: The topic name to subscribe to.
            callback: Callback function for received messages.
            qos_profile: QoS profile for the subscription (default: 10).

        Returns:
            The created ROS2 subscription.
        """
        subscription = node.create_subscription(msg_type, topic_name, callback, qos_profile)
        self._ros2_subscribers.append((node, subscription))
        return subscription

    def destroy_subscription(self, node, subscription):
        """Manually destroy a subscription and remove it from tracking.

        Args:
            node: The ROS2 node that owns the subscription.
            subscription: The subscription to destroy.
        """
        try:
            node.destroy_subscription(subscription)
            # Remove from tracking list to avoid double cleanup
            if (node, subscription) in self._ros2_subscribers:
                self._ros2_subscribers.remove((node, subscription))
        except Exception as e:
            print(f"Warning: Failed to destroy subscription: {e}")

    def destroy_publisher(self, node, publisher):
        """Manually destroy a publisher and remove it from tracking.

        Args:
            node: The ROS2 node that owns the publisher.
            publisher: The publisher to destroy.
        """
        try:
            node.destroy_publisher(publisher)
            # Remove from tracking list to avoid double cleanup
            if (node, publisher) in self._ros2_publishers:
                self._ros2_publishers.remove((node, publisher))
        except Exception as e:
            print(f"Warning: Failed to destroy publisher: {e}")

    async def tearDown(self):
        self._timeline.stop()
        await omni.kit.app.get_app().next_update_async()
        while omni.usd.get_context().get_stage_loading_status()[2] > 0:
            print("tearDown, assets still loading, waiting to finish...")
            await asyncio.sleep(1.0)

        # Clean up ROS2 resources in the correct order
        # First destroy publishers
        for node, publisher in self._ros2_publishers:
            try:
                node.destroy_publisher(publisher)
            except Exception as e:
                print(f"Warning: Failed to destroy publisher: {e}")

        # Then destroy subscribers
        for node, subscription in self._ros2_subscribers:
            try:
                node.destroy_subscription(subscription)
            except Exception as e:
                print(f"Warning: Failed to destroy subscription: {e}")

        # Finally destroy nodes
        for node in self._ros2_nodes:
            try:
                node.destroy_node()
            except Exception as e:
                print(f"Warning: Failed to destroy node: {e}")

        # Clear the tracking lists
        self._ros2_publishers.clear()
        self._ros2_subscribers.clear()
        self._ros2_nodes.clear()

        import rclpy

        if rclpy.ok():
            rclpy.shutdown()

        await super().tearDown()


def set_translate(prim, new_loc):
    from pxr import Gf, UsdGeom

    properties = prim.GetPropertyNames()
    if "xformOp:translate" in properties:
        translate_attr = prim.GetAttribute("xformOp:translate")

        translate_attr.Set(new_loc)
    elif "xformOp:translation" in properties:
        translation_attr = prim.GetAttribute("xformOp:translate")
        translation_attr.Set(new_loc)
    elif "xformOp:transform" in properties:
        transform_attr = prim.GetAttribute("xformOp:transform")
        matrix = prim.GetAttribute("xformOp:transform").Get()
        matrix.SetTranslateOnly(new_loc)
        transform_attr.Set(matrix)
    else:
        xform = UsdGeom.Xformable(prim)
        xform_op = xform.AddXformOp(UsdGeom.XformOp.TypeTransform, UsdGeom.XformOp.PrecisionDouble, "")
        xform_op.Set(Gf.Matrix4d().SetTranslate(new_loc))


def set_rotate(prim, rot_mat):
    from pxr import Gf, UsdGeom

    properties = prim.GetPropertyNames()
    if "xformOp:rotate" in properties:
        rotate_attr = prim.GetAttribute("xformOp:rotate")
        rotate_attr.Set(rot_mat)
    elif "xformOp:transform" in properties:
        transform_attr = prim.GetAttribute("xformOp:transform")
        matrix = prim.GetAttribute("xformOp:transform").Get()
        matrix.SetRotateOnly(rot_mat.ExtractRotation())
        transform_attr.Set(matrix)
    else:
        xform = UsdGeom.Xformable(prim)
        xform_op = xform.AddXformOp(UsdGeom.XformOp.TypeTransform, UsdGeom.XformOp.PrecisionDouble, "")
        xform_op.Set(Gf.Matrix4d().SetRotate(rot_mat))


async def add_cube(path, size, offset):
    from pxr import UsdGeom, UsdPhysics

    stage = omni.usd.get_context().get_stage()
    cubeGeom = UsdGeom.Cube.Define(stage, path)
    cubePrim = stage.GetPrimAtPath(path)

    cubeGeom.CreateSizeAttr(size)
    cubeGeom.AddTranslateOp().Set(offset)
    await omni.kit.app.get_app().next_update_async()  # Need this to avoid flatcache errors
    rigid_api = UsdPhysics.RigidBodyAPI.Apply(cubePrim)
    rigid_api.CreateRigidBodyEnabledAttr(True)
    UsdPhysics.CollisionAPI.Apply(cubePrim)

    return cubeGeom


async def add_carter(assets_root_path, prim_path="/Carter"):
    from pxr import Gf, PhysicsSchemaTools

    add_reference_to_stage(assets_root_path + "/Isaac/Robots/NVIDIA/Carter/carter_v1_physx_lidar.usd", prim_path)
    stage = omni.usd.get_context().get_stage()
    PhysicsSchemaTools.addGroundPlane(stage, "/World/groundPlane", "Z", 1500, Gf.Vec3f(0, 0, -0.25), Gf.Vec3f(0.5))
    await omni.kit.app.get_app().next_update_async()
    return prim_path


async def add_carter_ros(assets_root_path, prim_path="/Carter"):
    from pxr import Gf, PhysicsSchemaTools

    add_reference_to_stage(assets_root_path + "/Isaac/Samples/ROS2/Robots/Carter_ROS.usd", prim_path)
    await omni.kit.app.get_app().next_update_async()
    # Disabling cameras by default
    import omni.graph.core as og

    ros_cameras_graph_path = prim_path + "/ROS_Cameras"

    prims_to_disable = [
        ros_cameras_graph_path + "/isaac_create_render_product_left.inputs:enabled",
        ros_cameras_graph_path + "/isaac_create_render_product_right.inputs:enabled",
        ros_cameras_graph_path + "/ros2_camera_helper.inputs:enabled",
        ros_cameras_graph_path + "/ros2_camera_helper_01.inputs:enabled",
        ros_cameras_graph_path + "/ros2_camera_helper_03.inputs:enabled",
        ros_cameras_graph_path + "/ros2_camera_helper_04.inputs:enabled",
        ros_cameras_graph_path + "/ros2_camera_info_helper.inputs:enabled",
    ]
    for prim_to_disable in prims_to_disable:
        og.Controller.set(og.Controller.attribute(prim_to_disable), False)

    stage = omni.usd.get_context().get_stage()

    PhysicsSchemaTools.addGroundPlane(stage, "/World/groundPlane", "Z", 1500, Gf.Vec3f(0, 0, -0.25), Gf.Vec3f(0.5))
    await omni.kit.app.get_app().next_update_async()
    return prim_path


async def add_nova_carter_ros(assets_root_path):
    (result, error) = await open_stage_async(assets_root_path + "/Isaac/Samples/ROS2/Robots/Nova_Carter_ROS.usd")
    await omni.kit.app.get_app().next_update_async()


async def add_franka(assets_root_path):
    (result, error) = await open_stage_async(assets_root_path + "/Isaac/Robots/FrankaRobotics/FrankaPanda/franka.usd")


def get_qos_profile():
    from rclpy.qos import QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy

    return QoSProfile(reliability=QoSReliabilityPolicy.BEST_EFFORT, history=QoSHistoryPolicy.KEEP_LAST, depth=1)


def fields_to_dtype(fields, point_step):
    """Convert a list of PointFields to a numpy record datatype."""
    DUMMY_FIELD_PREFIX = "__"

    from sensor_msgs.msg import PointField

    # mappings between PointField types and numpy types
    type_mappings = [
        (PointField.INT8, np.dtype("int8")),
        (PointField.UINT8, np.dtype("uint8")),
        (PointField.INT16, np.dtype("int16")),
        (PointField.UINT16, np.dtype("uint16")),
        (PointField.INT32, np.dtype("int32")),
        (PointField.UINT32, np.dtype("uint32")),
        (PointField.FLOAT32, np.dtype("float32")),
        (PointField.FLOAT64, np.dtype("float64")),
    ]
    pftype_to_nptype = dict(type_mappings)
    nptype_to_pftype = dict((nptype, pftype) for pftype, nptype in type_mappings)

    # sizes (in bytes) of PointField types
    pftype_sizes = {
        PointField.INT8: 1,
        PointField.UINT8: 1,
        PointField.INT16: 2,
        PointField.UINT16: 2,
        PointField.INT32: 4,
        PointField.UINT32: 4,
        PointField.FLOAT32: 4,
        PointField.FLOAT64: 8,
    }

    offset = 0
    np_dtype_list = []
    for f in fields:
        while offset < f.offset:
            # might be extra padding between fields
            np_dtype_list.append(("%s%d" % (DUMMY_FIELD_PREFIX, offset), np.uint8))
            offset += 1

        dtype = pftype_to_nptype[f.datatype]
        if f.count != 1:
            dtype = np.dtype((dtype, f.count))

        np_dtype_list.append((f.name, dtype))
        offset += pftype_sizes[f.datatype] * f.count

    # might be extra padding between points
    while offset < point_step:
        np_dtype_list.append(("%s%d" % (DUMMY_FIELD_PREFIX, offset), np.uint8))
        offset += 1

    return np_dtype_list


def set_joint_drive_parameters(joint_path, joint_type, drive_type, target_value, stiffness=None, damping=None):
    stage = omni.usd.get_context().get_stage()
    drive = UsdPhysics.DriveAPI.Get(stage.GetPrimAtPath(joint_path), joint_type)

    if not drive:
        # if no drive exists, return false
        return False

    if drive_type == "position":
        if not drive.GetTargetPositionAttr():
            drive.CreateTargetPositionAttr(target_value)
        else:
            drive.GetTargetPositionAttr().Set(target_value)
    elif drive_type == "velocity":
        if not drive.GetTargetVelocityAttr():
            drive.CreateTargetVelocityAttr(target_value)
        else:
            drive.GetTargetVelocityAttr().Set(target_value)

    if stiffness is not None:
        if not drive.GetStiffnessAttr():
            drive.CreateStiffnessAttr(stiffness)
        else:
            drive.GetStiffnessAttr().Set(stiffness)

    if damping is not None:
        if not drive.GetDampingAttr():
            drive.CreateDampingAttr(damping)
        else:
            drive.GetDampingAttr().Set(damping)
